AllLife Bank is a US bank that has a growing customer base. The majority of these customers are liability customers (depositors) with varying sizes of deposits. The number of customers who are also borrowers (asset customers) is quite small, and the bank is interested in expanding this base rapidly to bring in more loan business and in the process, earn more through the interest on loans. In particular, the management wants to explore ways of converting its liability customers to personal loan customers (while retaining them as depositors).
A campaign that the bank ran last year for liability customers showed a healthy conversion rate of over 9%. This success has encouraged the retail marketing department to devise campaigns with better-targeted marketing to increase the success ratio.
You as a Data scientist at AllLife bank have to build a model that will help the marketing department to identify the potential customers who have a higher probability of purchasing the loan.
The objective is to predict whether a liability customer will buy a personal loan, to understand which customer attributes are most significant in driving purchases, and to identify which segment of customers to target more.
Data Dictionary:

- ID: Customer ID
- Age: Customer's age in completed years
- Experience: Number of years of professional experience
- Income: Annual income of the customer (in thousand dollars)
- ZIP Code: Home address ZIP code
- Family: Family size of the customer
- CCAvg: Average spending on credit cards per month (in thousand dollars)
- Education: Education level (1: Undergrad; 2: Graduate; 3: Advanced/Professional)
- Mortgage: Value of house mortgage, if any (in thousand dollars)
- Personal_Loan: Did this customer accept the personal loan offered in the last campaign? (0: No, 1: Yes)
- Securities_Account: Does the customer have a securities account with the bank? (0: No, 1: Yes)
- CD_Account: Does the customer have a certificate of deposit (CD) account with the bank? (0: No, 1: Yes)
- Online: Does the customer use internet banking facilities? (0: No, 1: Yes)
- CreditCard: Does the customer use a credit card issued by any other bank (excluding AllLife Bank)? (0: No, 1: Yes)

# Installing the libraries with the specified version.
#!pip install numpy==1.25.2 pandas==1.5.3 matplotlib==3.7.1 seaborn==0.13.1 scikit-learn==1.2.2 sklearn-pandas==2.2.0 -q --user
Note: the pinned installs above caused conflicts between pandas and the US zip code library, so the latest library versions available in Colab are used instead.
# Google Colab has library conflicts with nb_black, so it is not loaded
#%load_ext nb_black
# Library to suppress warnings or deprecation notes
import warnings
warnings.filterwarnings("ignore")
# Libraries to help with reading and manipulating data
import pandas as pd
import numpy as np
# Library to split data
from sklearn.model_selection import train_test_split
# Libraries to help with data visualization
import matplotlib.pyplot as plt
import seaborn as sns
# Removes the limit for the number of displayed columns
pd.set_option("display.max_columns", None)
# Sets the limit for the number of displayed rows
pd.set_option("display.max_rows", 200)
# Libraries to build decision tree classifier
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
# To tune different models
from sklearn.model_selection import GridSearchCV
# To perform statistical analysis
import scipy.stats as stats
# To get different metric scores
from sklearn.metrics import (
f1_score,
accuracy_score,
recall_score,
precision_score,
confusion_matrix,
ConfusionMatrixDisplay,
make_scorer,
)
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
# loading the dataset
data = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/machinelearning/Module project/Loan_Modelling.csv")
Understand the shape of the dataset.
data.shape
(5000, 14)
View the first and last 5 rows of the dataset.
data.head()
| ID | Age | Experience | Income | ZIPCode | Family | CCAvg | Education | Mortgage | Personal_Loan | Securities_Account | CD_Account | Online | CreditCard | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
| 1 | 2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
| 2 | 3 | 39 | 15 | 11 | 94720 | 1 | 1.0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 5 | 35 | 8 | 45 | 91330 | 4 | 1.0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
data.tail()
| ID | Age | Experience | Income | ZIPCode | Family | CCAvg | Education | Mortgage | Personal_Loan | Securities_Account | CD_Account | Online | CreditCard | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 4995 | 4996 | 29 | 3 | 40 | 92697 | 1 | 1.9 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
| 4996 | 4997 | 30 | 4 | 15 | 92037 | 4 | 0.4 | 1 | 85 | 0 | 0 | 0 | 1 | 0 |
| 4997 | 4998 | 63 | 39 | 24 | 93023 | 2 | 0.3 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4998 | 4999 | 65 | 40 | 49 | 90034 | 3 | 0.5 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
| 4999 | 5000 | 28 | 4 | 83 | 92612 | 3 | 0.8 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
# let's look at the structure of the data
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5000 entries, 0 to 4999
Data columns (total 14 columns):
 #   Column              Non-Null Count  Dtype
---  ------              --------------  -----
 0   ID                  5000 non-null   int64
 1   Age                 5000 non-null   int64
 2   Experience          5000 non-null   int64
 3   Income              5000 non-null   int64
 4   ZIPCode             5000 non-null   int64
 5   Family              5000 non-null   int64
 6   CCAvg               5000 non-null   float64
 7   Education           5000 non-null   int64
 8   Mortgage            5000 non-null   int64
 9   Personal_Loan       5000 non-null   int64
 10  Securities_Account  5000 non-null   int64
 11  CD_Account          5000 non-null   int64
 12  Online              5000 non-null   int64
 13  CreditCard          5000 non-null   int64
dtypes: float64(1), int64(13)
memory usage: 547.0 KB
Observations:
All variables are integers except CCAvg, which is a float. Some of the variables will need to be converted to categorical type for analysis.
data.describe(include="all").T
| count | mean | std | min | 25% | 50% | 75% | max | |
|---|---|---|---|---|---|---|---|---|
| ID | 5000.0 | 2500.500000 | 1443.520003 | 1.0 | 1250.75 | 2500.5 | 3750.25 | 5000.0 |
| Age | 5000.0 | 45.338400 | 11.463166 | 23.0 | 35.00 | 45.0 | 55.00 | 67.0 |
| Experience | 5000.0 | 20.104600 | 11.467954 | -3.0 | 10.00 | 20.0 | 30.00 | 43.0 |
| Income | 5000.0 | 73.774200 | 46.033729 | 8.0 | 39.00 | 64.0 | 98.00 | 224.0 |
| ZIPCode | 5000.0 | 93169.257000 | 1759.455086 | 90005.0 | 91911.00 | 93437.0 | 94608.00 | 96651.0 |
| Family | 5000.0 | 2.396400 | 1.147663 | 1.0 | 1.00 | 2.0 | 3.00 | 4.0 |
| CCAvg | 5000.0 | 1.937938 | 1.747659 | 0.0 | 0.70 | 1.5 | 2.50 | 10.0 |
| Education | 5000.0 | 1.881000 | 0.839869 | 1.0 | 1.00 | 2.0 | 3.00 | 3.0 |
| Mortgage | 5000.0 | 56.498800 | 101.713802 | 0.0 | 0.00 | 0.0 | 101.00 | 635.0 |
| Personal_Loan | 5000.0 | 0.096000 | 0.294621 | 0.0 | 0.00 | 0.0 | 0.00 | 1.0 |
| Securities_Account | 5000.0 | 0.104400 | 0.305809 | 0.0 | 0.00 | 0.0 | 0.00 | 1.0 |
| CD_Account | 5000.0 | 0.060400 | 0.238250 | 0.0 | 0.00 | 0.0 | 0.00 | 1.0 |
| Online | 5000.0 | 0.596800 | 0.490589 | 0.0 | 0.00 | 1.0 | 1.00 | 1.0 |
| CreditCard | 5000.0 | 0.294000 | 0.455637 | 0.0 | 0.00 | 0.0 | 1.00 | 1.0 |
Observations:
- Experience has a minimum of -3, which is not a valid value and will need to be fixed.
- Income and Mortgage are right-skewed (mean > median); the median Mortgage is 0, i.e., most customers have no mortgage.
- Only about 9.6% of customers accepted the personal loan in the last campaign, so the target classes are imbalanced.
# checking for unique values in ID column
data["ID"].nunique()
5000
Since all the values in the ID column are unique, it carries no predictive information, so we can drop it.
#Drop the ID column
data.drop(["ID"], axis=1, inplace=True)
#Check for nulls in dataset
data.isnull().sum()
Age                   0
Experience            0
Income                0
ZIPCode               0
Family                0
CCAvg                 0
Education             0
Mortgage              0
Personal_Loan         0
Securities_Account    0
CD_Account            0
Online                0
CreditCard            0
dtype: int64
# function to create labeled barplots
def labeled_barplot(data, feature, perc=False, n=None):
"""
Barplot with percentage at the top
data: dataframe
feature: dataframe column
perc: whether to display percentages instead of count (default is False)
n: displays the top n category levels (default is None, i.e., display all levels)
"""
total = len(data[feature]) # length of the column
count = data[feature].nunique()
if n is None:
plt.figure(figsize=(count + 2, 6))
else:
plt.figure(figsize=(n + 2, 6))
plt.xticks(rotation=90, fontsize=15)
ax = sns.countplot(
data=data,
x=feature,
palette="Paired",
order=data[feature].value_counts().index[:n],
)
for p in ax.patches:
if perc == True:
label = "{:.1f}%".format(
100 * p.get_height() / total
) # percentage of each class of the category
else:
label = p.get_height() # count of each level of the category
x = p.get_x() + p.get_width() / 2 # width of the plot
y = p.get_height() # height of the plot
ax.annotate(
label,
(x, y),
ha="center",
va="center",
size=12,
xytext=(0, 5),
textcoords="offset points",
) # annotate the percentage
plt.show() # show the plot
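The function above is not called in this section; as a quick illustration of what the percentage labels it annotates correspond to, here is the underlying computation on a small synthetic frame (illustrative values, not from the dataset):

```python
import pandas as pd

# Small synthetic stand-in for `data` (illustrative values only)
df = pd.DataFrame({"Education": [1, 1, 2, 3, 3, 3]})

# labeled_barplot(df, "Education", perc=True) would annotate each bar with
# 100 * count / total, i.e. the share of each category level
total = len(df["Education"])
labels = (df["Education"].value_counts() / total * 100).round(1)
print(labels.to_dict())  # {3: 50.0, 1: 33.3, 2: 16.7}
```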
# function to plot stacked bar chart
def stacked_barplot(data, predictor, target):
"""
Print the category counts and plot a stacked bar chart
data: dataframe
predictor: independent variable
target: target variable
"""
count = data[predictor].nunique()
sorter = data[target].value_counts().index[-1]
tab1 = pd.crosstab(data[predictor], data[target], margins=True).sort_values(
by=sorter, ascending=False
)
print(tab1)
print("-" * 120)
tab = pd.crosstab(data[predictor], data[target], normalize="index").sort_values(
by=sorter, ascending=False
)
tab.plot(kind="bar", stacked=True, figsize=(count + 5, 5))
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))  # place legend outside the plot
plt.show()
def histogram_boxplot(data, feature, figsize=(12, 7), kde=False, bins=None):
"""
Boxplot and histogram combined
data: dataframe
feature: dataframe column
figsize: size of figure (default (12,7))
kde: whether to show the density curve (default False)
bins: number of bins for histogram (default None)
"""
f2, (ax_box2, ax_hist2) = plt.subplots(
nrows=2, # Number of rows of the subplot grid= 2
sharex=True, # x-axis will be shared among all subplots
gridspec_kw={"height_ratios": (0.25, 0.75)},
figsize=figsize,
) # creating the 2 subplots
sns.boxplot(
data=data, x=feature, ax=ax_box2, showmeans=True, color="violet"
) # boxplot will be created and a star will indicate the mean value of the column
    if bins:
        sns.histplot(
            data=data, x=feature, kde=kde, ax=ax_hist2, bins=bins
        )  # histogram with the requested number of bins
    else:
        sns.histplot(data=data, x=feature, kde=kde, ax=ax_hist2)  # default binning
ax_hist2.axvline(
data[feature].mean(), color="green", linestyle="--"
) # Add mean to the histogram
ax_hist2.axvline(
data[feature].median(), color="black", linestyle="-"
) # Add median to the histogram
histogram_boxplot(data, "Age")
histogram_boxplot(data, "Experience")
histogram_boxplot(data, "Income")
histogram_boxplot(data, "Family")
histogram_boxplot(data, "CCAvg")
histogram_boxplot(data, "Education")
histogram_boxplot(data, "Mortgage")
# Filter out customers with no mortgage to understand the distribution better
data_m = data[data['Mortgage']!=0]
histogram_boxplot(data_m, "Mortgage")
histogram_boxplot(data, "Personal_Loan")
histogram_boxplot(data, "Securities_Account")
histogram_boxplot(data, "CD_Account")
histogram_boxplot(data, "Online")
histogram_boxplot(data, "CreditCard")
### function to plot distributions wrt target
def distribution_plot_wrt_target(data, predictor, target):
fig, axs = plt.subplots(2, 2, figsize=(12, 10))
target_uniq = data[target].unique()
axs[0, 0].set_title("Distribution of target for target=" + str(target_uniq[0]))
sns.histplot(
data=data[data[target] == target_uniq[0]],
x=predictor,
kde=True,
ax=axs[0, 0],
color="teal",
stat="density",
)
axs[0, 1].set_title("Distribution of target for target=" + str(target_uniq[1]))
sns.histplot(
data=data[data[target] == target_uniq[1]],
x=predictor,
kde=True,
ax=axs[0, 1],
color="orange",
stat="density",
)
axs[1, 0].set_title("Boxplot w.r.t target")
sns.boxplot(data=data, x=target, y=predictor, ax=axs[1, 0], palette="gist_rainbow")
axs[1, 1].set_title("Boxplot (without outliers) w.r.t target")
sns.boxplot(
data=data,
x=target,
y=predictor,
ax=axs[1, 1],
showfliers=False,
palette="gist_rainbow",
)
plt.tight_layout()
plt.show()
stacked_barplot(data, "Age","Personal_Loan")
Personal_Loan 0 1 All Age All 4520 480 5000 34 116 18 134 30 119 17 136 36 91 16 107 63 92 16 108 35 135 16 151 33 105 15 120 52 130 15 145 29 108 15 123 54 128 15 143 43 134 15 149 42 112 14 126 56 121 14 135 65 66 14 80 44 107 14 121 50 125 13 138 45 114 13 127 46 114 13 127 26 65 13 78 32 108 12 120 57 120 12 132 38 103 12 115 27 79 12 91 48 106 12 118 61 110 12 122 53 101 11 112 51 119 10 129 60 117 10 127 58 133 10 143 49 105 10 115 47 103 10 113 59 123 9 132 28 94 9 103 62 114 9 123 55 116 9 125 64 70 8 78 41 128 8 136 40 117 8 125 37 98 8 106 31 118 7 125 39 127 6 133 24 28 0 28 25 53 0 53 66 24 0 24 67 12 0 12 23 12 0 12 ------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "Age", "Personal_Loan")
stacked_barplot(data, "Education","Personal_Loan")
Personal_Loan     0    1   All
Education
All            4520  480  5000
3              1296  205  1501
2              1221  182  1403
1              2003   93  2096
------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "Education", "Personal_Loan")
stacked_barplot(data, "Income", "Personal_Loan")
Personal_Loan 0 1 All Income All 4520 480 5000 130 8 11 19 182 2 11 13 158 8 10 18 135 8 10 18 179 8 9 17 141 15 9 24 154 12 9 21 123 9 9 18 184 3 9 12 142 7 8 15 131 11 8 19 129 15 8 23 172 3 8 11 173 5 8 13 170 4 8 12 180 10 8 18 115 19 8 27 125 16 7 23 164 6 7 13 188 3 7 10 83 67 7 74 114 23 7 30 161 9 7 16 122 17 7 24 133 8 7 15 132 11 7 18 191 6 7 13 134 13 7 20 111 15 7 22 190 4 7 11 145 17 6 23 140 13 6 19 178 4 6 10 118 13 6 19 185 3 6 9 165 5 6 11 168 2 6 8 169 1 6 7 183 6 6 12 120 11 6 17 139 10 6 16 113 29 5 34 119 13 5 18 99 19 5 24 138 13 5 18 155 14 5 19 195 10 5 15 174 4 5 9 175 7 5 12 152 10 5 15 153 7 4 11 181 4 4 8 103 14 4 18 93 33 4 37 108 12 4 16 101 20 4 24 194 4 4 8 192 2 4 6 193 2 4 6 143 5 4 9 149 16 4 20 171 5 4 9 160 8 4 12 159 3 4 7 128 20 4 24 148 7 4 11 162 7 3 10 112 23 3 26 110 16 3 19 124 9 3 12 105 17 3 20 104 17 3 20 102 13 3 16 109 15 3 18 95 22 3 25 150 9 2 11 94 24 2 26 163 7 2 9 91 35 2 37 98 26 2 28 89 32 2 34 121 18 2 20 85 63 2 65 144 5 2 7 65 59 1 60 71 42 1 43 69 45 1 46 100 9 1 10 60 51 1 52 189 1 1 2 73 43 1 44 151 3 1 4 201 4 1 5 64 59 1 60 202 1 1 2 81 82 1 83 92 28 1 29 90 37 1 38 75 46 1 47 82 60 1 61 84 62 1 63 203 1 1 2 33 51 0 51 198 3 0 3 31 55 0 55 30 63 0 63 29 67 0 67 28 63 0 63 25 64 0 64 24 47 0 47 23 54 0 54 22 65 0 65 224 1 0 1 21 65 0 65 20 47 0 47 19 52 0 52 200 3 0 3 218 1 0 1 18 53 0 53 204 3 0 3 15 33 0 33 199 3 0 3 14 31 0 31 13 32 0 32 12 30 0 30 205 2 0 2 11 27 0 27 10 23 0 23 32 58 0 58 48 44 0 44 34 53 0 53 35 65 0 65 9 26 0 26 88 26 0 26 80 56 0 56 79 53 0 53 78 61 0 61 74 45 0 45 72 41 0 41 70 47 0 47 68 35 0 35 63 46 0 46 62 55 0 55 61 57 0 57 59 53 0 53 58 55 0 55 55 61 0 61 54 52 0 52 53 57 0 57 52 47 0 47 51 41 0 41 50 45 0 45 49 52 0 52 45 69 0 69 44 85 0 85 43 70 0 70 42 77 0 77 41 82 0 82 40 78 0 78 39 81 0 81 38 84 0 84 8 23 0 23 ------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "Income", "Personal_Loan")
stacked_barplot(data, "Experience", "Personal_Loan")
Personal_Loan 0 1 All Experience All 4520 480 5000 9 127 20 147 8 101 18 119 20 131 17 148 3 112 17 129 12 86 16 102 32 140 14 154 19 121 14 135 5 132 14 146 25 128 14 142 26 120 14 134 37 103 13 116 11 103 13 116 16 114 13 127 30 113 13 126 22 111 13 124 35 130 13 143 23 131 13 144 36 102 12 114 29 112 12 124 7 109 12 121 6 107 12 119 18 125 12 137 31 92 12 104 28 127 11 138 21 102 11 113 13 106 11 117 17 114 11 125 34 115 10 125 39 75 10 85 27 115 10 125 4 104 9 113 2 76 9 85 24 123 8 131 1 66 8 74 38 80 8 88 10 111 7 118 33 110 7 117 0 59 7 66 41 36 7 43 14 121 6 127 15 114 5 119 40 53 4 57 42 8 0 8 43 3 0 3 -2 15 0 15 -1 33 0 33 -3 4 0 4 ------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "Experience", "Personal_Loan")
stacked_barplot(data, "Family", "Personal_Loan")
Personal_Loan     0    1   All
Family
All            4520  480  5000
4              1088  134  1222
3               877  133  1010
1              1365  107  1472
2              1190  106  1296
------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "Family", "Personal_Loan")
stacked_barplot(data, "CCAvg", "Personal_Loan")
Personal_Loan 0 1 All CCAvg All 4520 480 5000 3.0 34 19 53 4.1 9 13 22 3.4 26 13 39 3.1 8 12 20 4.2 0 11 11 5.4 8 10 18 6.5 8 10 18 3.8 33 10 43 3.6 17 10 27 3.3 35 10 45 5.0 9 9 18 3.9 18 9 27 2.9 45 9 54 2.6 79 8 87 6.0 18 8 26 4.4 9 8 17 4.3 18 8 26 0.2 196 8 204 0.5 155 8 163 4.7 17 7 24 5.2 9 7 16 1.3 121 7 128 2.7 51 7 58 3.7 18 7 25 1.1 77 7 84 5.6 0 7 7 4.0 26 7 33 2.2 123 7 130 4.8 0 7 7 5.1 0 6 6 0.7 163 6 169 6.1 8 6 14 1.2 60 6 66 3.5 9 6 15 4.6 8 6 14 0.3 235 6 241 0.8 182 5 187 6.9 9 5 14 4.9 17 5 22 6.3 8 5 13 3.2 17 5 22 2.3 53 5 58 1.4 131 5 136 2.8 105 5 110 7.0 9 5 14 5.7 8 5 13 2.4 87 5 92 5.9 0 5 5 1.9 102 4 106 1.7 154 4 158 7.9 0 4 4 4.5 25 4 29 0.6 114 4 118 5.5 0 4 4 2.0 184 4 188 0.4 175 4 179 5.3 0 4 4 7.4 9 4 13 1.5 174 4 178 7.2 9 4 13 6.6 0 4 4 1.6 98 3 101 10.0 0 3 3 6.4 0 3 3 0.9 103 3 106 8.0 9 3 12 7.5 9 3 12 2.1 97 3 100 1.8 149 3 152 5.8 0 3 3 6.2 0 2 2 9.0 0 2 2 8.5 0 2 2 6.8 8 2 10 8.3 0 2 2 4.25 0 2 2 5.67 0 2 2 4.75 0 2 2 0.1 181 2 183 1.0 229 2 231 2.5 105 2 107 7.3 9 1 10 9.3 0 1 1 8.9 0 1 1 8.8 8 1 9 8.2 0 1 1 8.1 9 1 10 5.33 0 1 1 0.0 105 1 106 3.67 0 1 1 3.25 0 1 1 3.33 0 1 1 4.67 0 1 1 6.33 9 1 10 2.75 0 1 1 2.67 36 0 36 0.67 18 0 18 0.75 9 0 9 4.33 9 0 9 8.6 8 0 8 6.67 9 0 9 1.33 9 0 9 6.7 9 0 9 1.67 18 0 18 1.75 9 0 9 7.8 9 0 9 7.6 9 0 9 2.33 18 0 18 ------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "CCAvg", "Personal_Loan")
stacked_barplot(data, "Mortgage", "Personal_Loan")
Personal_Loan 0 1 All Mortgage All 4520 480 5000 0 3150 312 3462 301 0 5 5 342 1 3 4 282 0 3 3 ... ... ... ... 276 2 0 2 156 5 0 5 278 1 0 1 280 2 0 2 248 3 0 3 [348 rows x 3 columns] ------------------------------------------------------------------------------------------------------------------------
distribution_plot_wrt_target(data, "Mortgage", "Personal_Loan")
stacked_barplot(data, "Securities_Account", "Personal_Loan")
Personal_Loan          0    1   All
Securities_Account
All                 4520  480  5000
0                   4058  420  4478
1                    462   60   522
------------------------------------------------------------------------------------------------------------------------
stacked_barplot(data, "CD_Account", "Personal_Loan")
Personal_Loan     0    1   All
CD_Account
All            4520  480  5000
0              4358  340  4698
1               162  140   302
------------------------------------------------------------------------------------------------------------------------
Customers with a CD account show a much higher tendency to take a personal loan.
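This can be quantified with a row-normalized crosstab. A minimal sketch on a toy frame (in the notebook, `data` would be passed directly):

```python
import pandas as pd

# Toy stand-in for `data` (illustrative values only)
df = pd.DataFrame({
    "CD_Account":    [0, 0, 0, 0, 1, 1],
    "Personal_Loan": [0, 0, 0, 1, 1, 1],
})

# Row-normalized crosstab gives the loan conversion rate within each group
rates = pd.crosstab(df["CD_Account"], df["Personal_Loan"], normalize="index")
print(rates[1].to_dict())  # {0: 0.25, 1: 1.0}
```

On the full dataset, the counts above correspond to roughly 140/302 ≈ 46% conversion among CD holders versus 340/4698 ≈ 7% for customers without a CD account.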
stacked_barplot(data, "Online", "Personal_Loan")
Personal_Loan     0    1   All
Online
All            4520  480  5000
1              2693  291  2984
0              1827  189  2016
------------------------------------------------------------------------------------------------------------------------
Use of internet banking does not appear to influence personal loan purchase.
# Find out the number of customers having credit cards
data['CreditCard'].value_counts()
CreditCard
0    3530
1    1470
Name: count, dtype: int64
Holding a credit card from another bank does not appear to have any direct relationship with personal loan purchase.
stacked_barplot(data, "CreditCard", "Personal_Loan")
Personal_Loan     0    1   All
CreditCard
All            4520  480  5000
0              3193  337  3530
1              1327  143  1470
------------------------------------------------------------------------------------------------------------------------
# Find out correlation between the numeric variables
numeric_df = data.select_dtypes(include=['float64', 'int64'])
corr = numeric_df.corr()
# plot the heatmap
plt.figure(figsize=(15, 7))
sns.heatmap(corr, annot=True, vmin=-1, vmax=1, fmt=".2f", cmap="Spectral")
plt.show()
There is a very high positive correlation (0.99) between Age and Experience. Income is positively correlated with both CCAvg and Personal_Loan.
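To rank attributes by their correlation with the target directly, the target column of the correlation matrix can be sorted. A sketch on a toy numeric frame (in the notebook, `numeric_df` would be used):

```python
import pandas as pd

# Toy numeric frame (illustrative values only)
df = pd.DataFrame({
    "Income":        [8, 20, 50, 100, 150, 200],
    "Age":           [60, 30, 45, 28, 55, 40],
    "Personal_Loan": [0, 0, 0, 1, 1, 1],
})

# Correlation of every attribute with the target, strongest first
target_corr = (
    df.corr()["Personal_Loan"].drop("Personal_Loan").sort_values(ascending=False)
)
print(target_corr.index[0])  # Income
```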
# Probe Mortgage data for outliers
Q1 = data['Mortgage'].quantile(0.25)  # 25th percentile
Q3 = data['Mortgage'].quantile(0.75)  # 75th percentile
IQR = Q3 - Q1  # interquartile range (75th percentile - 25th percentile)
lower = (
Q1 - 1.5 * IQR
) # Finding lower and upper bounds for all values. All values outside these bounds are outliers
upper = Q3 + 1.5 * IQR
print("Lower bound for Mortgage",lower)
print("Upper bound for Mortgage",upper)
Lower bound for Mortgage -151.5
Upper bound for Mortgage 252.5
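With these bounds, values outside them can be flagged. A minimal sketch on a toy series (the notebook would use `data['Mortgage']` with the `lower`/`upper` computed above):

```python
import pandas as pd

# Toy Mortgage values (illustrative only)
mortgage = pd.Series([0, 0, 0, 101, 300, 635])

q1, q3 = mortgage.quantile(0.25), mortgage.quantile(0.75)
iqr = q3 - q1
lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr

# Values outside [lo, hi] are treated as outliers
outliers = mortgage[(mortgage < lo) | (mortgage > hi)]
print(outliers.tolist())  # [635]
```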
Questions:
How many customers have credit cards?
What are the attributes that have a strong correlation with the target attribute (personal loan)?
How does a customer's interest in purchasing a loan vary with their age?
How does a customer's interest in purchasing a loan vary with their education?
# Installing the sqlalchemy-mate version that the uszipcode library is compatible with
!pip install sqlalchemy-mate==1.4.28.3
Collecting sqlalchemy-mate==1.4.28.3
  Downloading sqlalchemy_mate-1.4.28.3-py2.py3-none-any.whl (76 kB)
Installing collected packages: sqlalchemy-mate
Successfully installed sqlalchemy-mate-1.4.28.3
#Installing USzipcode library to feature engineer zipcode
!pip install uszipcode
Collecting uszipcode
  Downloading uszipcode-1.0.1-py2.py3-none-any.whl (35 kB)
Collecting pathlib-mate (from uszipcode)
Collecting atomicwrites (from uszipcode)
Collecting fuzzywuzzy (from uszipcode)
Collecting haversine>=2.5.0 (from uszipcode)
Successfully built atomicwrites
Installing collected packages: fuzzywuzzy, pathlib-mate, haversine, atomicwrites, uszipcode
Successfully installed atomicwrites-1.4.1 fuzzywuzzy-0.18.0 haversine-2.8.1 pathlib-mate-1.3.2 uszipcode-1.0.1
# Map each ZIP code to its major metro city; a few ZIP codes do not resolve in the database and are mapped manually
from uszipcode import SearchEngine, SimpleZipcode
search = SearchEngine()
def zco(x):
city = search.by_zipcode(x)
if city :
return city.major_city
elif (x == 92717 or x == 92634 or x== 96651 ):
return "Los Angeles"
elif (x == 93077):
return "Fresno"
else:
return "None"
data['metro'] = data['ZIPCode'].apply(zco)
Download /root/.uszipcode/simple_db.sqlite from https://github.com/MacHu-GWU/uszipcode-project/releases/download/1.0.1.db/simple_db.sqlite ... Complete!
# Understand the distribution of customers across cities
data['metro'].value_counts()
metro
Los Angeles 408
San Diego 269
San Francisco 257
Berkeley 241
Sacramento 148
...
Sausalito 1
Ladera Ranch 1
Sierra Madre 1
Tahoe City 1
Stinson Beach 1
Name: count, Length: 244, dtype: int64
# Check for any ZIP codes that were not matched
zipData = data[data['metro'] == 'None']
zipData['ZIPCode'].value_counts()
Series([], Name: count, dtype: int64)
# checking for experience <0
data[data["Experience"] < 0]["Experience"].unique()
array([-1, -2, -3])
# Negative experience is likely a data-entry error; convert negative values to positive
data["Experience"].replace(-1, 1, inplace=True)
data["Experience"].replace(-2, 2, inplace=True)
data["Experience"].replace(-3, 3, inplace=True)
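A quick sanity check confirms the fix; note the three `replace()` calls are equivalent to taking the absolute value here, since only -1, -2, and -3 occur. A sketch on a toy series:

```python
import pandas as pd

# Toy Experience column containing the same negative values seen above
exp = pd.Series([9, -1, -2, -3, 20])

# Equivalent to the three replace() calls: flip the negative entries
exp = exp.abs()
print((exp < 0).sum())  # 0 -> no negative experience values remain
```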
# Convert the listed columns to categorical type
cat_cols = [
"Education",
"Personal_Loan",
"Securities_Account",
"CD_Account",
"Online",
"CreditCard"
]
data[cat_cols] = data[cat_cols].astype("category")
# Drop the target variable and the ZIPCode column (location is now captured by 'metro')
X = data.drop(["Personal_Loan","ZIPCode"], axis=1)
y = data["Personal_Loan"]
# encoding the categorical variables
X = pd.get_dummies(X, drop_first=True)
X.head()
[X.head() output truncated: 5 rows × 255 columns — Age, Experience, Income, Family, CCAvg, Mortgage, the encoded dummies Education_2, Education_3, Securities_Account_1, CD_Account_1, Online_1, CreditCard_1, and one-hot `metro_<city>` indicator columns (243 after drop_first).]
False | False | False |
| 1 | 45 | 19 | 34 | 3 | 1.5 | 0 | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False 
| False | False | False |
| 2 | 39 | 15 | 11 | 1 | 1.0 | 0 | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False 
| False | False | False |
| 3 | 35 | 9 | 100 | 1 | 2.7 | 0 | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False 
| False | False | False |
| 4 | 35 | 8 | 45 | 4 | 1.0 | 0 | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | 
False | False | False |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
print("Number of rows in train data =", X_train.shape[0])
print("Number of rows in test data =", X_test.shape[0])
Number of rows in train data = 3500
Number of rows in test data = 1500
print("Percentage of classes in training set:")
print(y_train.value_counts(normalize=True))
print("Percentage of classes in test set:")
print(y_test.value_counts(normalize=True))
Percentage of classes in training set:
Personal_Loan
0    0.905429
1    0.094571
Name: proportion, dtype: float64
Percentage of classes in test set:
Personal_Loan
0    0.900667
1    0.099333
Name: proportion, dtype: float64
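The class proportions above are close but not identical across the two splits because the split was random. Passing `stratify=y` to `train_test_split` would preserve the class ratio exactly in both sets. A minimal sketch on toy data (the 90/10 split below is assumed for illustration):

```python
from sklearn.model_selection import train_test_split
import pandas as pd

# Toy imbalanced target (90% class 0, 10% class 1) -- assumed data for illustration
X = pd.DataFrame({"x": range(100)})
y = pd.Series([0] * 90 + [1] * 10, name="Personal_Loan")

# stratify=y keeps the 90/10 class ratio identical in train and test
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.3, random_state=1, stratify=y
)

print(y_tr.value_counts(normalize=True))  # 0 -> 0.9, 1 -> 0.1
print(y_te.value_counts(normalize=True))  # 0 -> 0.9, 1 -> 0.1
```

With a minority class this small, stratification guards against a split that accidentally starves either set of positive examples.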
False negative: predicting that a customer is not eligible for a personal loan when, in reality, the customer would be a good candidate for one.
False positive: predicting that a customer is eligible for a personal loan when, in reality, the customer would not be a good candidate for one.
Recall should be maximized: the greater the recall, the fewer the false negatives, i.e., the fewer potential loan customers the bank misses.
model = DecisionTreeClassifier(criterion="gini", random_state=1)
model.fit(X_train, y_train)
DecisionTreeClassifier(random_state=1)
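Since recall is the metric to maximize here, it helps to see how recall and precision relate to false negatives and false positives. A toy sketch with assumed confusion-matrix counts (not taken from this dataset):

```python
# Assumed confusion-matrix counts for illustration:
# tp = loan takers correctly flagged, fn = loan takers missed,
# fp = non-takers flagged,            tn = non-takers correctly ignored
tp, fn, fp, tn = 120, 30, 10, 1340

recall = tp / (tp + fn)     # share of actual loan takers the model catches
precision = tp / (tp + fp)  # share of flagged customers who actually convert

print(round(recall, 3))     # 0.8
print(round(precision, 3))  # 0.923
```

Every false negative removed raises recall directly, which is why recall is the right lever when missed loan candidates are the costly error.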
First, let's create functions to compute the different metrics and the confusion matrix, so that we don't have to repeat the same code for each model. The model_performance_classification_sklearn function will be used to check model performance, and the confusion_matrix_sklearn function will be used to plot the confusion matrix.
# defining a function to compute different metrics to check performance of a classification model built using sklearn
def model_performance_classification_sklearn(model, predictors, target):
    """
    Function to compute different metrics to check classification model performance

    model: classifier
    predictors: independent variables
    target: dependent variable
    """
    # predicting using the independent variables
    pred = model.predict(predictors)

    acc = accuracy_score(target, pred)  # to compute Accuracy
    recall = recall_score(target, pred)  # to compute Recall
    precision = precision_score(target, pred)  # to compute Precision
    f1 = f1_score(target, pred)  # to compute F1-score

    # creating a dataframe of metrics
    df_perf = pd.DataFrame(
        {"Accuracy": acc, "Recall": recall, "Precision": precision, "F1": f1},
        index=[0],
    )

    return df_perf
def confusion_matrix_sklearn(model, predictors, target):
    """
    To plot the confusion matrix with percentages

    model: classifier
    predictors: independent variables
    target: dependent variable
    """
    y_pred = model.predict(predictors)
    cm = confusion_matrix(target, y_pred)
    labels = np.asarray(
        [
            ["{0:0.0f}".format(item) + "\n{0:.2%}".format(item / cm.flatten().sum())]
            for item in cm.flatten()
        ]
    ).reshape(2, 2)
    plt.figure(figsize=(6, 4))
    sns.heatmap(cm, annot=labels, fmt="")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
decision_tree_perf_train = model_performance_classification_sklearn(
    model, X_train, y_train
)
decision_tree_perf_train
| Accuracy | Recall | Precision | F1 | |
|---|---|---|---|---|
| 0 | 1.0 | 1.0 | 1.0 | 1.0 |
The model appears to be overfitting the training data: accuracy, recall, precision, and F1 are all 1.0. Further refinement (pruning or hyperparameter tuning) is needed.
confusion_matrix_sklearn(model, X_train, y_train)
decision_tree_perf_test = model_performance_classification_sklearn(
    model, X_test, y_test
)
decision_tree_perf_test
| Accuracy | Recall | Precision | F1 | |
|---|---|---|---|---|
| 0 | 0.982667 | 0.879195 | 0.942446 | 0.909722 |
confusion_matrix_sklearn(model, X_test, y_test)
column_names = list(X.columns)
feature_names = column_names
print(feature_names)
['Age', 'Experience', 'Income', 'Family', 'CCAvg', 'Mortgage', 'Education_2', 'Education_3', 'Securities_Account_1', 'CD_Account_1', 'Online_1', 'CreditCard_1', 'metro_Alameda', 'metro_Alamo', 'metro_Albany', 'metro_Alhambra', 'metro_Anaheim', 'metro_Antioch', 'metro_Aptos', 'metro_Arcadia', 'metro_Arcata', 'metro_Bakersfield', 'metro_Baldwin Park', 'metro_Banning', 'metro_Bella Vista', 'metro_Belmont', 'metro_Belvedere Tiburon', 'metro_Ben Lomond', 'metro_Berkeley', 'metro_Beverly Hills', 'metro_Bodega Bay', 'metro_Bonita', 'metro_Boulder Creek', 'metro_Brea', 'metro_Brisbane', 'metro_Burlingame', 'metro_Calabasas', 'metro_Camarillo', 'metro_Campbell', 'metro_Canoga Park', 'metro_Capistrano Beach', 'metro_Capitola', 'metro_Cardiff By The Sea', 'metro_Carlsbad', 'metro_Carpinteria', 'metro_Carson', 'metro_Castro Valley', 'metro_Ceres', 'metro_Chatsworth', 'metro_Chico', 'metro_Chino', 'metro_Chino Hills', 'metro_Chula Vista', 'metro_Citrus Heights', 'metro_Claremont', 'metro_Clearlake', 'metro_Clovis', 'metro_Concord', 'metro_Costa Mesa', 'metro_Crestline', 'metro_Culver City', 'metro_Cupertino', 'metro_Cypress', 'metro_Daly City', 'metro_Danville', 'metro_Davis', 'metro_Diamond Bar', 'metro_Edwards', 'metro_El Dorado Hills', 'metro_El Segundo', 'metro_El Sobrante', 'metro_Elk Grove', 'metro_Emeryville', 'metro_Encinitas', 'metro_Escondido', 'metro_Eureka', 'metro_Fairfield', 'metro_Fallbrook', 'metro_Fawnskin', 'metro_Folsom', 'metro_Fremont', 'metro_Fresno', 'metro_Fullerton', 'metro_Garden Grove', 'metro_Gilroy', 'metro_Glendale', 'metro_Glendora', 'metro_Goleta', 'metro_Greenbrae', 'metro_Hacienda Heights', 'metro_Half Moon Bay', 'metro_Hawthorne', 'metro_Hayward', 'metro_Hermosa Beach', 'metro_Highland', 'metro_Hollister', 'metro_Hopland', 'metro_Huntington Beach', 'metro_Imperial', 'metro_Inglewood', 'metro_Irvine', 'metro_La Jolla', 'metro_La Mesa', 'metro_La Mirada', 'metro_La Palma', 'metro_Ladera Ranch', 'metro_Laguna Hills', 'metro_Laguna Niguel', 
'metro_Lake Forest', 'metro_Larkspur', 'metro_Livermore', 'metro_Loma Linda', 'metro_Lomita', 'metro_Lompoc', 'metro_Long Beach', 'metro_Los Alamitos', 'metro_Los Altos', 'metro_Los Angeles', 'metro_Los Gatos', 'metro_Manhattan Beach', 'metro_March Air Reserve Base', 'metro_Marina', 'metro_Martinez', 'metro_Menlo Park', 'metro_Merced', 'metro_Milpitas', 'metro_Mission Hills', 'metro_Mission Viejo', 'metro_Modesto', 'metro_Monrovia', 'metro_Montague', 'metro_Montclair', 'metro_Montebello', 'metro_Monterey', 'metro_Monterey Park', 'metro_Moraga', 'metro_Morgan Hill', 'metro_Moss Landing', 'metro_Mountain View', 'metro_Napa', 'metro_National City', 'metro_Newbury Park', 'metro_Newport Beach', 'metro_North Hills', 'metro_North Hollywood', 'metro_Northridge', 'metro_Norwalk', 'metro_Novato', 'metro_Oak View', 'metro_Oakland', 'metro_Oceanside', 'metro_Ojai', 'metro_Orange', 'metro_Oxnard', 'metro_Pacific Grove', 'metro_Pacific Palisades', 'metro_Palo Alto', 'metro_Palos Verdes Peninsula', 'metro_Pasadena', 'metro_Placentia', 'metro_Pleasant Hill', 'metro_Pleasanton', 'metro_Pomona', 'metro_Porter Ranch', 'metro_Portola Valley', 'metro_Poway', 'metro_Rancho Cordova', 'metro_Rancho Cucamonga', 'metro_Rancho Palos Verdes', 'metro_Redding', 'metro_Redlands', 'metro_Redondo Beach', 'metro_Redwood City', 'metro_Reseda', 'metro_Richmond', 'metro_Ridgecrest', 'metro_Rio Vista', 'metro_Riverside', 'metro_Rohnert Park', 'metro_Rosemead', 'metro_Roseville', 'metro_Sacramento', 'metro_Salinas', 'metro_San Anselmo', 'metro_San Bernardino', 'metro_San Bruno', 'metro_San Clemente', 'metro_San Diego', 'metro_San Dimas', 'metro_San Francisco', 'metro_San Gabriel', 'metro_San Jose', 'metro_San Juan Bautista', 'metro_San Juan Capistrano', 'metro_San Leandro', 'metro_San Luis Obispo', 'metro_San Luis Rey', 'metro_San Marcos', 'metro_San Mateo', 'metro_San Pablo', 'metro_San Rafael', 'metro_San Ramon', 'metro_San Ysidro', 'metro_Sanger', 'metro_Santa Ana', 'metro_Santa Barbara', 
'metro_Santa Clara', 'metro_Santa Clarita', 'metro_Santa Cruz', 'metro_Santa Monica', 'metro_Santa Rosa', 'metro_Santa Ynez', 'metro_Saratoga', 'metro_Sausalito', 'metro_Seal Beach', 'metro_Seaside', 'metro_Sherman Oaks', 'metro_Sierra Madre', 'metro_Signal Hill', 'metro_Simi Valley', 'metro_Sonora', 'metro_South Gate', 'metro_South Lake Tahoe', 'metro_South Pasadena', 'metro_South San Francisco', 'metro_Stanford', 'metro_Stinson Beach', 'metro_Stockton', 'metro_Studio City', 'metro_Sunland', 'metro_Sunnyvale', 'metro_Sylmar', 'metro_Tahoe City', 'metro_Tehachapi', 'metro_Thousand Oaks', 'metro_Torrance', 'metro_Trinity Center', 'metro_Tustin', 'metro_Ukiah', 'metro_Upland', 'metro_Valencia', 'metro_Vallejo', 'metro_Van Nuys', 'metro_Venice', 'metro_Ventura', 'metro_Vista', 'metro_Walnut Creek', 'metro_Weed', 'metro_West Covina', 'metro_West Sacramento', 'metro_Westlake Village', 'metro_Whittier', 'metro_Woodland Hills', 'metro_Yorba Linda', 'metro_Yucaipa']
plt.figure(figsize=(20, 30))
out = tree.plot_tree(
    model,
    feature_names=feature_names,
    filled=True,
    fontsize=9,
    node_ids=True,
    class_names=True,
)
for o in out:
    arrow = o.arrow_patch
    if arrow is not None:
        arrow.set_edgecolor("black")
        arrow.set_linewidth(1)
plt.show()
# Text report showing the rules of a decision tree -
print(tree.export_text(model, feature_names=feature_names, show_weights=True))
|--- Income <= 116.50 | |--- CCAvg <= 2.95 | | |--- Income <= 106.50 | | | |--- weights: [2553.00, 0.00] class: 0 | | |--- Income > 106.50 | | | |--- Family <= 3.50 | | | | |--- metro_Cardiff By The Sea <= 0.50 | | | | | |--- metro_Santa Barbara <= 0.50 | | | | | | |--- Age <= 28.50 | | | | | | | |--- Education_2 <= 0.50 | | | | | | | | |--- weights: [5.00, 0.00] class: 0 | | | | | | | |--- Education_2 > 0.50 | | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | | |--- Age > 28.50 | | | | | | | |--- weights: [58.00, 0.00] class: 0 | | | | | |--- metro_Santa Barbara > 0.50 | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | |--- metro_Cardiff By The Sea > 0.50 | | | | | |--- weights: [0.00, 1.00] class: 1 | | | |--- Family > 3.50 | | | | |--- Experience <= 3.50 | | | | | |--- weights: [10.00, 0.00] class: 0 | | | | |--- Experience > 3.50 | | | | | |--- Age <= 60.00 | | | | | | |--- Experience <= 7.00 | | | | | | | |--- Experience <= 4.50 | | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | | | |--- Experience > 4.50 | | | | | | | | |--- weights: [2.00, 0.00] class: 0 | | | | | | |--- Experience > 7.00 | | | | | | | |--- weights: [0.00, 6.00] class: 1 | | | | | |--- Age > 60.00 | | | | | | |--- weights: [4.00, 0.00] class: 0 | |--- CCAvg > 2.95 | | |--- Income <= 92.50 | | | |--- CD_Account_1 <= 0.50 | | | | |--- Age <= 26.50 | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | |--- Age > 26.50 | | | | | |--- metro_Banning <= 0.50 | | | | | | |--- metro_Glendale <= 0.50 | | | | | | | |--- metro_San Francisco <= 0.50 | | | | | | | | |--- metro_Whittier <= 0.50 | | | | | | | | | |--- metro_Riverside <= 0.50 | | | | | | | | | | |--- Age <= 62.50 | | | | | | | | | | | |--- truncated branch of depth 5 | | | | | | | | | | |--- Age > 62.50 | | | | | | | | | | | |--- truncated branch of depth 2 | | | | | | | | | |--- metro_Riverside > 0.50 | | | | | | | | | | |--- Income <= 73.50 | | | | | | | | | | | |--- weights: [1.00, 0.00] class: 0 | | | | | | | | 
| | |--- Income > 73.50 | | | | | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | | | | |--- metro_Whittier > 0.50 | | | | | | | | | |--- Family <= 2.50 | | | | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | | | | | |--- Family > 2.50 | | | | | | | | | | |--- weights: [1.00, 0.00] class: 0 | | | | | | | |--- metro_San Francisco > 0.50 | | | | | | | | |--- Income <= 82.50 | | | | | | | | | |--- weights: [4.00, 0.00] class: 0 | | | | | | | | |--- Income > 82.50 | | | | | | | | | |--- weights: [0.00, 2.00] class: 1 | | | | | | |--- metro_Glendale > 0.50 | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | |--- metro_Banning > 0.50 | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | |--- CD_Account_1 > 0.50 | | | | |--- weights: [0.00, 5.00] class: 1 | | |--- Income > 92.50 | | | |--- Family <= 2.50 | | | | |--- Education_2 <= 0.50 | | | | | |--- Education_3 <= 0.50 | | | | | | |--- CD_Account_1 <= 0.50 | | | | | | | |--- Age <= 56.50 | | | | | | | | |--- weights: [27.00, 0.00] class: 0 | | | | | | | |--- Age > 56.50 | | | | | | | | |--- Online_1 <= 0.50 | | | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | | | | |--- Online_1 > 0.50 | | | | | | | | | |--- weights: [2.00, 0.00] class: 0 | | | | | | |--- CD_Account_1 > 0.50 | | | | | | | |--- CCAvg <= 4.75 | | | | | | | | |--- weights: [0.00, 2.00] class: 1 | | | | | | | |--- CCAvg > 4.75 | | | | | | | | |--- weights: [1.00, 0.00] class: 0 | | | | | |--- Education_3 > 0.50 | | | | | | |--- Experience <= 25.50 | | | | | | | |--- CCAvg <= 3.95 | | | | | | | | |--- metro_Redondo Beach <= 0.50 | | | | | | | | | |--- weights: [0.00, 3.00] class: 1 | | | | | | | | |--- metro_Redondo Beach > 0.50 | | | | | | | | | |--- weights: [1.00, 0.00] class: 0 | | | | | | | |--- CCAvg > 3.95 | | | | | | | | |--- weights: [6.00, 0.00] class: 0 | | | | | | |--- Experience > 25.50 | | | | | | | |--- weights: [0.00, 4.00] class: 1 | | | | |--- Education_2 > 0.50 | | | | | |--- weights: [0.00, 4.00] 
class: 1 | | | |--- Family > 2.50 | | | | |--- Age <= 57.50 | | | | | |--- metro_El Segundo <= 0.50 | | | | | | |--- weights: [0.00, 20.00] class: 1 | | | | | |--- metro_El Segundo > 0.50 | | | | | | |--- weights: [1.00, 0.00] class: 0 | | | | |--- Age > 57.50 | | | | | |--- CD_Account_1 <= 0.50 | | | | | | |--- metro_Oakland <= 0.50 | | | | | | | |--- weights: [7.00, 0.00] class: 0 | | | | | | |--- metro_Oakland > 0.50 | | | | | | | |--- weights: [0.00, 1.00] class: 1 | | | | | |--- CD_Account_1 > 0.50 | | | | | | |--- weights: [0.00, 2.00] class: 1 |--- Income > 116.50 | |--- Family <= 2.50 | | |--- Education_3 <= 0.50 | | | |--- Education_2 <= 0.50 | | | | |--- weights: [375.00, 0.00] class: 0 | | | |--- Education_2 > 0.50 | | | | |--- weights: [0.00, 53.00] class: 1 | | |--- Education_3 > 0.50 | | | |--- weights: [0.00, 62.00] class: 1 | |--- Family > 2.50 | | |--- weights: [0.00, 154.00] class: 1
importances = model.feature_importances_
indices = np.argsort(importances)
plt.figure(figsize=(12, 40))
plt.title("Feature Importances")
plt.barh(range(len(indices)), importances[indices], color="violet", align="center")
plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
plt.xlabel("Relative Importance")
plt.show()
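With this many one-hot `metro_*` columns, a numeric top-k listing can be easier to scan than the bar chart. A sketch of the same idea, with illustrative values standing in for `model.feature_importances_` and `feature_names`:

```python
import pandas as pd

# Illustrative importances and names -- stand-ins for
# model.feature_importances_ and the real feature_names list
feature_names = ["Income", "Education_2", "Family", "CCAvg", "Age"]
importances = [0.40, 0.25, 0.20, 0.10, 0.05]

# pair names with importances, sort descending, keep the top 3
top = (
    pd.Series(importances, index=feature_names)
    .sort_values(ascending=False)
    .head(3)
)
print(top)
```

On the real model, substituting the fitted `importances` array yields a quick ranking of which attributes drive loan purchases.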
# Choose the type of classifier.
estimator = DecisionTreeClassifier(random_state=1)

# Grid of parameters to choose from
parameters = {
    "max_depth": np.arange(6, 15),
    "min_samples_leaf": [1, 2, 5, 7, 10],
    "max_leaf_nodes": [2, 3, 5, 10],
    "criterion": ["entropy", "gini"],
    "splitter": ["best", "random"],
}

# Type of scoring used to compare parameter combinations
acc_scorer = make_scorer(recall_score)

# Run the grid search
grid_obj = GridSearchCV(estimator, parameters, scoring=acc_scorer, cv=5)
grid_obj = grid_obj.fit(X_train, y_train)

# Set the estimator to the best combination of parameters
estimator = grid_obj.best_estimator_

# Fit the best algorithm to the data.
estimator.fit(X_train, y_train)
DecisionTreeClassifier(criterion='entropy', max_depth=6, max_leaf_nodes=10,
                       random_state=1)
decision_tree_tune_perf_train = model_performance_classification_sklearn(
    estimator, X_train, y_train
)
decision_tree_tune_perf_train
| Accuracy | Recall | Precision | F1 | |
|---|---|---|---|---|
| 0 | 0.985714 | 0.873112 | 0.973064 | 0.920382 |
confusion_matrix_sklearn(estimator, X_train, y_train)
decision_tree_tune_perf_test = model_performance_classification_sklearn(
    estimator, X_test, y_test
)
decision_tree_tune_perf_test
| Accuracy | Recall | Precision | F1 | |
|---|---|---|---|---|
| 0 | 0.976667 | 0.785235 | 0.975 | 0.869888 |
confusion_matrix_sklearn(estimator, X_test, y_test)
plt.figure(figsize=(15, 15))
tree.plot_tree(
    estimator,
    feature_names=feature_names,
    filled=True,
    fontsize=9,
    node_ids=True,
    class_names=True,
)
plt.show()
clf = DecisionTreeClassifier(random_state=1)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
pd.DataFrame(path)
| ccp_alphas | impurities | |
|---|---|---|
| 0 | 0.000000 | 0.000000 |
| 1 | 0.000269 | 0.000538 |
| 2 | 0.000275 | 0.002736 |
| 3 | 0.000276 | 0.003288 |
| 4 | 0.000281 | 0.003851 |
| 5 | 0.000381 | 0.004232 |
| 6 | 0.000381 | 0.004613 |
| 7 | 0.000429 | 0.005041 |
| 8 | 0.000500 | 0.005541 |
| 9 | 0.000506 | 0.008070 |
| 10 | 0.000508 | 0.008578 |
| 11 | 0.000537 | 0.009651 |
| 12 | 0.000544 | 0.010196 |
| 13 | 0.000625 | 0.010821 |
| 14 | 0.000700 | 0.011521 |
| 15 | 0.000771 | 0.012292 |
| 16 | 0.000792 | 0.015460 |
| 17 | 0.000800 | 0.016260 |
| 18 | 0.000940 | 0.017200 |
| 19 | 0.001305 | 0.018505 |
| 20 | 0.001647 | 0.020153 |
| 21 | 0.002333 | 0.022486 |
| 22 | 0.002407 | 0.024893 |
| 23 | 0.003294 | 0.028187 |
| 24 | 0.006473 | 0.034659 |
| 25 | 0.025146 | 0.084951 |
| 26 | 0.039216 | 0.124167 |
| 27 | 0.047088 | 0.171255 |
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")
plt.show()
Next, we train a decision tree using the effective alphas. The last value in ccp_alphas is the alpha value that prunes the whole tree, leaving the tree, clfs[-1], with one node.
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=1, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
print(
    "Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
        clfs[-1].tree_.node_count, ccp_alphas[-1]
    )
)
Number of nodes in the last tree is: 1 with ccp_alpha: 0.04708834100596766
For the remainder, we remove the last element in clfs and ccp_alphas, because it is the trivial tree with only one node. Here we show that the number of nodes and tree depth decrease as alpha increases.
clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]
node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1, figsize=(10, 7))
ax[0].plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker="o", drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()
recall_train = []
for clf in clfs:
    pred_train = clf.predict(X_train)
    values_train = recall_score(y_train, pred_train)
    recall_train.append(values_train)

recall_test = []
for clf in clfs:
    pred_test = clf.predict(X_test)
    values_test = recall_score(y_test, pred_test)
    recall_test.append(values_test)
fig, ax = plt.subplots(figsize=(15, 5))
ax.set_xlabel("alpha")
ax.set_ylabel("Recall")
ax.set_title("Recall vs alpha for training and testing sets")
ax.plot(ccp_alphas, recall_train, marker="o", label="train", drawstyle="steps-post")
ax.plot(ccp_alphas, recall_test, marker="o", label="test", drawstyle="steps-post")
ax.legend()
plt.show()
# selecting the model with the highest test recall
index_best_model = np.argmax(recall_test)
best_model = clfs[index_best_model]
print(best_model)
DecisionTreeClassifier(random_state=1)
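Note that `np.argmax` returns the first index achieving the maximum test recall, which here is the unpruned tree (alpha = 0). When several alphas tie on test recall, preferring the largest tied alpha gives the simplest (most heavily pruned) tree. A sketch with illustrative values standing in for `recall_test` and `ccp_alphas`:

```python
import numpy as np

# Illustrative recall-vs-alpha values -- stand-ins for the real
# recall_test list and ccp_alphas array computed above
ccp_alphas = np.array([0.000, 0.001, 0.002, 0.005])
recall_test = np.array([0.88, 0.88, 0.85, 0.70])

best = recall_test.max()
# among all alphas tying for the best test recall, take the LARGEST
# (the most heavily pruned, hence simplest, tree)
idx = int(np.where(recall_test == best)[0][-1])
print(idx, ccp_alphas[idx])  # 1 0.001
```

This tie-breaking rule trades nothing in test recall for a smaller, more interpretable tree.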
decision_tree_postpruned_perf_train = model_performance_classification_sklearn(
    best_model, X_train, y_train
)
decision_tree_postpruned_perf_train
| Accuracy | Recall | Precision | F1 | |
|---|---|---|---|---|
| 0 | 1.0 | 1.0 | 1.0 | 1.0 |
confusion_matrix_sklearn(best_model, X_train, y_train)
decision_tree_postpruned_perf_test = model_performance_classification_sklearn(
    best_model, X_test, y_test
)
decision_tree_postpruned_perf_test
| Accuracy | Recall | Precision | F1 | |
|---|---|---|---|---|
| 0 | 0.982667 | 0.879195 | 0.942446 | 0.909722 |
confusion_matrix_sklearn(best_model, X_test, y_test)
plt.figure(figsize=(10, 20))
out = tree.plot_tree(
    best_model,
    feature_names=feature_names,
    filled=True,
    fontsize=9,
    node_ids=True,
    class_names=True,
)
for o in out:
    arrow = o.arrow_patch
    if arrow is not None:
        arrow.set_edgecolor("black")
        arrow.set_linewidth(1)
plt.show()
# Text report showing the rules of a decision tree -
print(tree.export_text(best_model, feature_names=feature_names, show_weights=True))
|--- Income <= 116.50
|   |--- CCAvg <= 2.95
|   |   |--- Income <= 106.50
|   |   |   |--- weights: [2553.00, 0.00] class: 0
|   |   |--- Income > 106.50
|   |   |   |--- Family <= 3.50
|   |   |   |   |--- metro_Cardiff By The Sea <= 0.50
|   |   |   |   |   |--- metro_Santa Barbara <= 0.50
|   |   |   |   |   |   |--- Age <= 28.50
|   |   |   |   |   |   |   |--- Education_2 <= 0.50
|   |   |   |   |   |   |   |   |--- weights: [5.00, 0.00] class: 0
|   |   |   |   |   |   |   |--- Education_2 > 0.50
|   |   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |   |--- Age > 28.50
|   |   |   |   |   |   |   |--- weights: [58.00, 0.00] class: 0
|   |   |   |   |   |--- metro_Santa Barbara > 0.50
|   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |--- metro_Cardiff By The Sea > 0.50
|   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |--- Family > 3.50
|   |   |   |   |--- Experience <= 3.50
|   |   |   |   |   |--- weights: [10.00, 0.00] class: 0
|   |   |   |   |--- Experience > 3.50
|   |   |   |   |   |--- Age <= 60.00
|   |   |   |   |   |   |--- Experience <= 7.00
|   |   |   |   |   |   |   |--- Experience <= 4.50
|   |   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |   |   |--- Experience > 4.50
|   |   |   |   |   |   |   |   |--- weights: [2.00, 0.00] class: 0
|   |   |   |   |   |   |--- Experience > 7.00
|   |   |   |   |   |   |   |--- weights: [0.00, 6.00] class: 1
|   |   |   |   |   |--- Age > 60.00
|   |   |   |   |   |   |--- weights: [4.00, 0.00] class: 0
|   |--- CCAvg > 2.95
|   |   |--- Income <= 92.50
|   |   |   |--- CD_Account_1 <= 0.50
|   |   |   |   |--- Age <= 26.50
|   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |--- Age > 26.50
|   |   |   |   |   |--- metro_Banning <= 0.50
|   |   |   |   |   |   |--- metro_Glendale <= 0.50
|   |   |   |   |   |   |   |--- metro_San Francisco <= 0.50
|   |   |   |   |   |   |   |   |--- metro_Whittier <= 0.50
|   |   |   |   |   |   |   |   |   |--- metro_Riverside <= 0.50
|   |   |   |   |   |   |   |   |   |   |--- Age <= 62.50
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |   |   |   |   |   |--- Age > 62.50
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |   |   |   |   |--- metro_Riverside > 0.50
|   |   |   |   |   |   |   |   |   |   |--- Income <= 73.50
|   |   |   |   |   |   |   |   |   |   |   |--- weights: [1.00, 0.00] class: 0
|   |   |   |   |   |   |   |   |   |   |--- Income > 73.50
|   |   |   |   |   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |   |   |   |--- metro_Whittier > 0.50
|   |   |   |   |   |   |   |   |   |--- Family <= 2.50
|   |   |   |   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |   |   |   |   |--- Family > 2.50
|   |   |   |   |   |   |   |   |   |   |--- weights: [1.00, 0.00] class: 0
|   |   |   |   |   |   |   |--- metro_San Francisco > 0.50
|   |   |   |   |   |   |   |   |--- Income <= 82.50
|   |   |   |   |   |   |   |   |   |--- weights: [4.00, 0.00] class: 0
|   |   |   |   |   |   |   |   |--- Income > 82.50
|   |   |   |   |   |   |   |   |   |--- weights: [0.00, 2.00] class: 1
|   |   |   |   |   |   |--- metro_Glendale > 0.50
|   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |--- metro_Banning > 0.50
|   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |--- CD_Account_1 > 0.50
|   |   |   |   |--- weights: [0.00, 5.00] class: 1
|   |   |--- Income > 92.50
|   |   |   |--- Family <= 2.50
|   |   |   |   |--- Education_2 <= 0.50
|   |   |   |   |   |--- Education_3 <= 0.50
|   |   |   |   |   |   |--- CD_Account_1 <= 0.50
|   |   |   |   |   |   |   |--- Age <= 56.50
|   |   |   |   |   |   |   |   |--- weights: [27.00, 0.00] class: 0
|   |   |   |   |   |   |   |--- Age > 56.50
|   |   |   |   |   |   |   |   |--- Online_1 <= 0.50
|   |   |   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |   |   |   |--- Online_1 > 0.50
|   |   |   |   |   |   |   |   |   |--- weights: [2.00, 0.00] class: 0
|   |   |   |   |   |   |--- CD_Account_1 > 0.50
|   |   |   |   |   |   |   |--- CCAvg <= 4.75
|   |   |   |   |   |   |   |   |--- weights: [0.00, 2.00] class: 1
|   |   |   |   |   |   |   |--- CCAvg > 4.75
|   |   |   |   |   |   |   |   |--- weights: [1.00, 0.00] class: 0
|   |   |   |   |   |--- Education_3 > 0.50
|   |   |   |   |   |   |--- Experience <= 25.50
|   |   |   |   |   |   |   |--- CCAvg <= 3.95
|   |   |   |   |   |   |   |   |--- metro_Redondo Beach <= 0.50
|   |   |   |   |   |   |   |   |   |--- weights: [0.00, 3.00] class: 1
|   |   |   |   |   |   |   |   |--- metro_Redondo Beach > 0.50
|   |   |   |   |   |   |   |   |   |--- weights: [1.00, 0.00] class: 0
|   |   |   |   |   |   |   |--- CCAvg > 3.95
|   |   |   |   |   |   |   |   |--- weights: [6.00, 0.00] class: 0
|   |   |   |   |   |   |--- Experience > 25.50
|   |   |   |   |   |   |   |--- weights: [0.00, 4.00] class: 1
|   |   |   |   |--- Education_2 > 0.50
|   |   |   |   |   |--- weights: [0.00, 4.00] class: 1
|   |   |   |--- Family > 2.50
|   |   |   |   |--- Age <= 57.50
|   |   |   |   |   |--- metro_El Segundo <= 0.50
|   |   |   |   |   |   |--- weights: [0.00, 20.00] class: 1
|   |   |   |   |   |--- metro_El Segundo > 0.50
|   |   |   |   |   |   |--- weights: [1.00, 0.00] class: 0
|   |   |   |   |--- Age > 57.50
|   |   |   |   |   |--- CD_Account_1 <= 0.50
|   |   |   |   |   |   |--- metro_Oakland <= 0.50
|   |   |   |   |   |   |   |--- weights: [7.00, 0.00] class: 0
|   |   |   |   |   |   |--- metro_Oakland > 0.50
|   |   |   |   |   |   |   |--- weights: [0.00, 1.00] class: 1
|   |   |   |   |   |--- CD_Account_1 > 0.50
|   |   |   |   |   |   |--- weights: [0.00, 2.00] class: 1
|--- Income > 116.50
|   |--- Family <= 2.50
|   |   |--- Education_3 <= 0.50
|   |   |   |--- Education_2 <= 0.50
|   |   |   |   |--- weights: [375.00, 0.00] class: 0
|   |   |   |--- Education_2 > 0.50
|   |   |   |   |--- weights: [0.00, 53.00] class: 1
|   |   |--- Education_3 > 0.50
|   |   |   |--- weights: [0.00, 62.00] class: 1
|   |--- Family > 2.50
|   |   |--- weights: [0.00, 154.00] class: 1
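A rule dump like the one above can be produced with sklearn's `export_text`, which prints one `|---` node per line with `show_weights=True` adding the per-class sample weights at each leaf. A minimal sketch on synthetic data (the two feature names here are purely illustrative, not the bank's columns):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# toy data: the label depends only on the first feature
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = (X[:, 0] > 0.5).astype(int)

tree = DecisionTreeClassifier(max_depth=2, random_state=1).fit(X, y)

# feature_names maps column indices to readable names in the printed rules
rules = export_text(tree, feature_names=["Income", "CCAvg"], show_weights=True)
print(rules)
```

`export_text` also accepts `max_depth` (default 10), which is why deep branches in a large tree appear as "truncated branch of depth N" in the dump above.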
# Feature importances from the fitted tree. The importance of a feature is the
# (normalized) total reduction of the splitting criterion brought by that feature,
# also known as the Gini importance.
print(
    pd.DataFrame(
        best_model.feature_importances_, columns=["Imp"], index=X_train.columns
    ).sort_values(by="Imp", ascending=False)
)
                             Imp
Income                  0.303933
Family                  0.248530
Education_2             0.165971
Education_3             0.144207
CCAvg                   0.046194
...                          ...
metro_Huntington Beach  0.000000
metro_Imperial          0.000000
metro_Inglewood         0.000000
metro_Irvine            0.000000
metro_Yucaipa           0.000000

[255 rows x 1 columns]
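The Gini importances above are normalized so the column sums to 1, and any feature never chosen for a split (here, most of the one-hot `metro_*` columns) gets exactly 0. A quick check of both properties on a toy tree (illustrative data, not the bank's):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# toy data: label driven by feature 0, weakly by feature 1; feature 2 is noise
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = (X[:, 0] + 0.2 * X[:, 1] > 0).astype(int)

clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
imp = clf.feature_importances_

print(imp)        # feature 0 should dominate
print(imp.sum())  # normalized importances sum to 1
```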
importances = best_model.feature_importances_
indices = np.argsort(importances)
plt.figure(figsize=(12, 40))
plt.title("Feature Importances")
plt.barh(range(len(indices)), importances[indices], color="violet", align="center")
plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
plt.xlabel("Relative Importance")
plt.show()
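With 255 mostly-zero importances, the 40-inch figure above is hard to scan; a top-k view is often more readable. A sketch with stand-in numbers (in the notebook you would use `best_model.feature_importances_` and `X_train.columns` instead):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# stand-in importances mimicking the table above (values are illustrative)
importances = np.array([0.30, 0.25, 0.17, 0.14, 0.05, 0.0, 0.0, 0.0, 0.0])
names = np.array(["Income", "Family", "Education_2", "Education_3",
                  "CCAvg", "f5", "f6", "f7", "f8"])

k = 5
top = np.argsort(importances)[-k:]  # indices of the k largest, ascending

plt.figure(figsize=(8, 4))
plt.barh(range(k), importances[top], color="violet", align="center")
plt.yticks(range(k), names[top])
plt.xlabel("Relative Importance")
plt.title(f"Top {k} Feature Importances")
plt.tight_layout()
```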
# training performance comparison
models_train_comp_df = pd.concat(
    [
        decision_tree_perf_train.T,
        decision_tree_tune_perf_train.T,
        decision_tree_postpruned_perf_train.T,
    ],
    axis=1,
)
models_train_comp_df.columns = [
    "Decision Tree sklearn",
    "Decision Tree (Pre-Pruning)",
    "Decision Tree (Post-Pruning)",
]
print("Training performance comparison:")
models_train_comp_df
Training performance comparison:
|  | Decision Tree sklearn | Decision Tree (Pre-Pruning) | Decision Tree (Post-Pruning) |
|---|---|---|---|
| Accuracy | 1.0 | 0.985714 | 1.0 |
| Recall | 1.0 | 0.873112 | 1.0 |
| Precision | 1.0 | 0.973064 | 1.0 |
| F1 | 1.0 | 0.920382 | 1.0 |
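The `decision_tree_*_perf_*` frames concatenated above come from a scoring helper defined earlier in the notebook; a minimal stand-alone version of such a helper might look like the following (the function name and toy data are illustrative, not the notebook's originals):

```python
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.tree import DecisionTreeClassifier

def model_performance_classification(model, X, y):
    """Score a fitted classifier on (X, y); one row, metrics as columns (so .T stacks models)."""
    pred = model.predict(X)
    return pd.DataFrame(
        {
            "Accuracy": accuracy_score(y, pred),
            "Recall": recall_score(y, pred),
            "Precision": precision_score(y, pred),
            "F1": f1_score(y, pred),
        },
        index=[0],
    )

# toy demonstration
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = (X[:, 0] > 0).astype(int)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
perf = model_performance_classification(clf, X, y)
print(perf)
```

Returning metrics as columns is what makes the `.T` + `pd.concat(..., axis=1)` pattern above produce one column per model.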
# test performance comparison
models_test_comp_df = pd.concat(
    [
        decision_tree_perf_test.T,
        decision_tree_tune_perf_test.T,
        decision_tree_postpruned_perf_test.T,
    ],
    axis=1,
)
models_test_comp_df.columns = [
    "Decision Tree sklearn",
    "Decision Tree (Pre-Pruning)",
    "Decision Tree (Post-Pruning)",
]
print("Test set performance comparison:")
models_test_comp_df
Test set performance comparison:
|  | Decision Tree sklearn | Decision Tree (Pre-Pruning) | Decision Tree (Post-Pruning) |
|---|---|---|---|
| Accuracy | 0.982667 | 0.976667 | 0.982667 |
| Recall | 0.879195 | 0.785235 | 0.879195 |
| Precision | 0.942446 | 0.975000 | 0.942446 |
| F1 | 0.909722 | 0.869888 | 0.909722 |
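One way to read the two tables: the pre-pruned tree trades recall for precision, while the post-pruned tree matches the unpruned tree's test recall (~0.88) despite fitting the training set perfectly. The train-test gap per model can be computed directly; the numbers below are transcribed from the tables above:

```python
import pandas as pd

metrics = ["Accuracy", "Recall", "Precision", "F1"]
train = pd.DataFrame(
    {"Decision Tree sklearn": [1.0, 1.0, 1.0, 1.0],
     "Decision Tree (Pre-Pruning)": [0.985714, 0.873112, 0.973064, 0.920382],
     "Decision Tree (Post-Pruning)": [1.0, 1.0, 1.0, 1.0]},
    index=metrics,
)
test = pd.DataFrame(
    {"Decision Tree sklearn": [0.982667, 0.879195, 0.942446, 0.909722],
     "Decision Tree (Pre-Pruning)": [0.976667, 0.785235, 0.975000, 0.869888],
     "Decision Tree (Post-Pruning)": [0.982667, 0.879195, 0.942446, 0.909722]},
    index=metrics,
)

gap = (train - test).round(6)  # positive values = drop from train to test
print(gap)
```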